"""BiGym Robot."""
import dataclasses
from pathlib import Path
from typing import Optional, Dict, Union, Iterable

import numpy as np
from dm_control import mjcf
from mojo import Mojo
from mojo.elements import Body, Site, MujocoElement, Geom
from mojo.elements.consts import JointType
from mujoco_utils import mjcf_utils
from pyquaternion import Quaternion

from bigym.action_modes import (
    ActionMode,
    TorqueActionMode,
    JointPositionActionMode,
    PelvisDof,
)
from bigym.const import (
    ASSETS_PATH,
    HandSide,
    TOLERANCE_LINEAR,
    TOLERANCE_ANGULAR,
    WORLD_MODEL,
)
from bigym.envs.props.prop import Prop
from bigym.utils.dof import Dof
from bigym.utils.physics_utils import (
    is_target_reached,
    get_critical_damping_from_stiffness,
    has_collided_collections,
    get_colliders,
)

# Gain (kp) of the position actuators that replace limb motors in
# JointPositionActionMode.
STIFFNESS_POSITION_ACTUATOR = 300
# Gain (kp) of the floating base position actuators for the X, Y and RZ DOFs.
STIFFNESS_PELVIS_DOF = 10000
# The vertical DOF uses a ten times higher gain than the other pelvis DOFs.
STIFFNESS_PELVIS_DOF_Z = STIFFNESS_PELVIS_DOF * 10

# Used both as the joint range and as the control range of the vertical
# pelvis DOF (presumably meters - TODO confirm units).
RANGE_PELVIS_DOF_Z = (0.4, 1.0)

# Description of each supported floating base DOF. RobotFloatingBase creates
# one joint on the pelvis and one position actuator per requested entry.
FLOATING_BASE_DOFS: Dict[PelvisDof, Dof] = {
    # Unlimited slide along the X axis.
    PelvisDof.X: Dof(
        joint_type=JointType.SLIDE,
        axis=(1, 0, 0),
        stiffness=STIFFNESS_PELVIS_DOF,
    ),
    # Unlimited slide along the Y axis.
    PelvisDof.Y: Dof(
        joint_type=JointType.SLIDE,
        axis=(0, 1, 0),
        stiffness=STIFFNESS_PELVIS_DOF,
    ),
    # Slide along the Z axis; the only DOF with joint and control limits.
    PelvisDof.Z: Dof(
        joint_type=JointType.SLIDE,
        axis=(0, 0, 1),
        joint_range=RANGE_PELVIS_DOF_Z,
        action_range=RANGE_PELVIS_DOF_Z,
        stiffness=STIFFNESS_PELVIS_DOF_Z,
    ),
    # Unlimited rotation around the Z axis.
    PelvisDof.RZ: Dof(
        joint_type=JointType.HINGE,
        axis=(0, 0, 1),
        stiffness=STIFFNESS_PELVIS_DOF,
    ),
}

# Name given to the wrist joint (and its motor) that Robot._add_wrist creates
# for each hand; the wrist body is named "<joint name>_link".
WRIST_JOINTS = {
    HandSide.LEFT: "left_wrist",
    HandSide.RIGHT: "right_wrist",
}

# Arm link names per side. Not referenced elsewhere in this file; presumably
# body names of the robot model - verify against h1.xml.
ARM_LINKS = {
    HandSide.LEFT: [
        "left_shoulder_pitch_link",
        "left_shoulder_roll_link",
        "left_shoulder_yaw_link",
        "left_elbow_link",
    ],
    HandSide.RIGHT: [
        "right_shoulder_pitch_link",
        "right_shoulder_roll_link",
        "right_shoulder_yaw_link",
        "right_elbow_link",
    ],
}

# Leg body names per side. AnimatedLegs removes the first entry of each list
# from the pelvis and animates the bodies named like the last three entries.
LEG_LINKS = {
    HandSide.LEFT: [
        "left_hip_yaw_link",
        "left_hip_roll_link",
        "left_hip_pitch_link",
        "left_knee_link",
        "left_ankle_link",
    ],
    HandSide.RIGHT: [
        "right_hip_yaw_link",
        "right_hip_roll_link",
        "right_hip_pitch_link",
        "right_knee_link",
        "right_ankle_link",
    ],
}

# Names of the sites the grippers are attached to.
HAND_SITES = {
    HandSide.LEFT: "left_end_effector",
    HandSide.RIGHT: "right_end_effector",
}

# Motor actuators of the robot model that are treated as limb actuators.
# The flag tells whether the actuator is kept when the floating base is
# enabled: entries set to False are removed together with their joints.
HUMANOID_ACTUATORS = {
    "left_hip_yaw": False,
    "left_hip_roll": False,
    "left_hip_pitch": False,
    "left_knee": False,
    "left_ankle": False,
    "right_hip_yaw": False,
    "right_hip_roll": False,
    "right_hip_pitch": False,
    "right_knee": False,
    "right_ankle": False,
    "torso": False,
    "left_shoulder_pitch": True,
    "left_shoulder_roll": True,
    "left_shoulder_yaw": True,
    "left_elbow": True,
    "left_wrist": True,
    "right_shoulder_pitch": True,
    "right_shoulder_roll": True,
    "right_shoulder_yaw": True,
    "right_elbow": True,
    "right_wrist": True,
}

# Name of the "general" actuators that are collected as gripper actuators.
GRIPPER_ACTUATOR = "fingers_actuator"

# Model loaded by Robot.
ROBOT_MODEL = ASSETS_PATH / "h1/h1.xml"
# Model loaded by AnimatedLegs in place of the original legs.
FLOATING_BASE_MODEL = ASSETS_PATH / "h1/h1_floating_base.xml"
# Joints with this name are excluded from the joints reported by Robot.
FLOATING_BASE_JOINT = "h1_floating_base"


class AnimatedLegs:
    """Procedural leg animation used in the floating base mode.

    The original legs are removed from the pelvis and replaced with the bodies
    of ``FLOATING_BASE_MODEL``. The orientation of these bodies is then written
    directly on every step from a two-link solution for the pelvis height.
    """

    _LINKS_COUNT = 3
    _L1 = 0.4
    _L2 = 0.4
    _2_L1_L2 = 2 * _L1 * _L2
    _L1SQ_L2SQ = _L1 * _L1 + _L2 * _L2

    _HIPS_OFFSET = 0.1742
    _ANKLE_HEIGHT = 0.08

    _STEP_HEIGHT = 0.04
    _STEP_DURATION = 0.5

    def __init__(self, mojo: Mojo, pelvis: Body):
        """Replace the legs of ``pelvis`` with the animated floating base model."""
        self._mojo = mojo
        self._pelvis = pelvis

        # Drop the root link of each original leg, then the pelvis mesh
        for side in HandSide:
            root_link = LEG_LINKS[side][0]
            Body.get(self._mojo, root_link, pelvis).mjcf.remove()
        self._pelvis.geoms[0].mjcf.remove()

        # The replacement model is attached to a tiny site on the pelvis
        anchor = Site.create(
            self._mojo, self._pelvis, size=np.array([0.001, 0.001, 0.001])
        )
        self._mojo.load_model(FLOATING_BASE_MODEL, anchor, on_loaded=self._on_loaded)

    def step(self, pelvis_z: float, is_moving: bool = True):
        """Update all leg bodies for the given pelvis height.

        When ``is_moving`` is False the step offset is scaled to zero, so all
        legs get the solution computed without any offset.
        """
        scale = 1 if is_moving else 0
        both_sides = (HandSide.LEFT, HandSide.RIGHT)

        # Collision legs: one set at zero offset, one at the full step height
        lowest = self._solve(pelvis_z, 0)
        highest = self._solve(pelvis_z, self._STEP_HEIGHT * scale)
        for legs, solution in (
            (self._collision_min, lowest),
            (self._collision_max, highest),
        ):
            for side in both_sides:
                self._set_leg_state(legs[side], solution)

        # Visual legs: the right leg is phase shifted by pi / 2
        phase_shifts = {HandSide.LEFT: 0, HandSide.RIGHT: np.pi / 2}
        solutions = {
            side: self._solve(pelvis_z, self._get_offset(phase_shifts[side], scale))
            for side in both_sides
        }
        for side in both_sides:
            self._set_leg_state(self._visual[side], solutions[side])

    def _on_loaded(self, model: mjcf.RootElement):
        """Collect the animated bodies of the loaded floating base model."""
        floating_base = MujocoElement(self._mojo, model)

        def collect(suffix):
            # Bodies named after the last `_LINKS_COUNT` leg links plus suffix
            return {
                side: [
                    Body.get(self._mojo, f"{link}{suffix}", floating_base)
                    for link in LEG_LINKS[side][-self._LINKS_COUNT :]
                ]
                for side in (HandSide.LEFT, HandSide.RIGHT)
            }

        self._visual: dict[HandSide, list[Body]] = collect("")
        self._collision_min: dict[HandSide, list[Body]] = collect("_collision_min")
        self._collision_max: dict[HandSide, list[Body]] = collect("_collision_max")

    def _get_offset(self, shift: float, scale: float) -> float:
        """Get the step offset at the current physics time.

        ``scale`` is clipped to [0, 1]; ``shift`` is added to the phase.
        """
        scale = np.clip(scale, 0, 1)
        t = self._mojo.physics.time() % self._STEP_DURATION
        phase = 2 * np.pi * (t / self._STEP_DURATION) + shift
        lift = 1 - np.abs(np.sin(phase))
        return scale * self._STEP_HEIGHT * lift

    def _set_leg_state(self, leg: list[Body], state: np.ndarray):
        """Rotate each body of ``leg`` around the Y axis by its angle in ``state``."""
        for angle, body in zip(state, leg):
            bound_body = self._mojo.physics.bind(body.mjcf)
            bound_body.quat = Quaternion(axis=[0, 1, 0], radians=angle).elements

    def _solve(self, pelvis_z: float, offset: float = 0.0) -> np.ndarray:
        """Get hip, knee and ankle angles for the given pelvis height.

        The hip to ankle distance is clipped to [0, _L1 + _L2] and the knee
        angle follows from the law of cosines for links ``_L1`` and ``_L2``.
        """
        r = pelvis_z - self._HIPS_OFFSET - self._ANKLE_HEIGHT - offset
        r = np.clip(r, 0, self._L1 + self._L2)
        interior = np.arccos((self._L1SQ_L2SQ - r * r) / self._2_L1_L2)
        knee = np.pi - interior
        hip = -knee / 2
        # The ankle uses the same angle as the hip
        return np.array([hip, knee, hip])


class RobotFloatingBase:
    """Floating base of the robot to simplify control.

    For every requested pelvis DOF a joint is added to the pelvis together with
    a position actuator driving it. Controls passed to `set_control` are deltas
    which are added to the current targets of these actuators.

    Actuators are stored per axis: index 0, 1, 2 of the position/rotation
    actuator lists corresponds to the X, Y, Z axis, ``None`` marks a DOF which
    is not actuated.
    """

    DELTA_RANGE_POS: tuple[float, float] = (-0.01, 0.01)
    DELTA_RANGE_ROT: tuple[float, float] = (-0.05, 0.05)

    def __init__(
        self,
        pelvis: Body,
        floating_dofs: list[PelvisDof],
        model: mjcf.RootElement,
        mojo: Mojo,
    ):
        """Init.

        :param pelvis: Pelvis body the floating joints are added to.
        :param floating_dofs: DOFs to actuate, keys of `FLOATING_BASE_DOFS`.
        :param model: Robot model the actuators are added to.
        :param mojo: Mojo instance used to access physics.
        """
        self._pelvis: Body = pelvis
        self._mojo = mojo
        self._position_actuators: list[Optional[mjcf.Element]] = [None, None, None]
        self._rotation_actuators: list[Optional[mjcf.Element]] = [None, None, None]

        for i, floating_dof in enumerate(floating_dofs):
            dof = FLOATING_BASE_DOFS[floating_dof]
            joint = self._pelvis.mjcf.add(
                "joint",
                type=dof.joint_type.value,
                name=floating_dof.value,
                axis=dof.axis,
            )
            if dof.joint_range:
                joint.limited = True
                joint.range = dof.joint_range
            else:
                joint.limited = False

            # Inserted at index `i` of the actuator section of the model
            actuator = model.actuator.insert(
                "position",
                position=i,
                name=floating_dof.value,
                joint=joint,
                kp=dof.stiffness,
            )
            if dof.action_range:
                actuator.ctrllimited = True
                actuator.ctrlrange = dof.action_range
            else:
                actuator.ctrllimited = False

            self._add_actuator(
                positional=dof.joint_type == JointType.SLIDE,
                axis=np.array(dof.axis),
                actuator=actuator,
            )

        self._animated_legs = AnimatedLegs(self._mojo, self._pelvis)

        self._accumulated_actions: np.ndarray = np.zeros(len(self.all_actuators))
        self._last_action: np.ndarray = np.zeros(len(self.all_actuators))

    def reset(self, position: np.ndarray, quaternion: np.ndarray):
        """Set position and orientation of the floating base."""
        self._accumulated_actions *= 0
        self._last_action *= 0

        self._set_position(position)
        self._set_quaternion(quaternion)
        self._animated_legs.step(self._pelvis_z, False)

    def get_action_bounds(self) -> list[tuple[float, float]]:
        """Get action bounds of all actuators.

        Bounds are ordered like `all_actuators`: positional first, then
        rotational.
        """
        positional = [self.DELTA_RANGE_POS for a in self._position_actuators if a]
        rotational = [self.DELTA_RANGE_ROT for a in self._rotation_actuators if a]
        return positional + rotational

    def set_control(self, control: np.ndarray):
        """Set control of all actuators.

        :param control: Deltas ordered like `all_actuators`; each value is
            added to the current control of the matching actuator.
        """
        # The previous action is accumulated only when the next one arrives
        self._accumulated_actions += self._last_action
        self._last_action = control.copy()

        for index, actuator in enumerate(self.all_actuators):
            self._mojo.physics.bind(actuator).ctrl += control[index]
        self._animated_legs.step(self._pelvis_z)

    @property
    def is_target_reached(self) -> bool:
        """Check if the target state of all actuators is reached."""
        physics = self._mojo.physics
        linear_reached = all(
            is_target_reached(actuator, physics, TOLERANCE_LINEAR)
            for actuator in self._position_actuators
            if actuator
        )
        if not linear_reached:
            return False
        return all(
            is_target_reached(actuator, physics, TOLERANCE_ANGULAR)
            for actuator in self._rotation_actuators
            if actuator
        )

    @property
    def dof_amount(self) -> int:
        """Get number of actuated DOF."""
        return len(self.all_actuators)

    @property
    def qpos(self) -> np.ndarray:
        """Get positions of actuated joints, ordered like `all_actuators`."""
        physics = self._mojo.physics
        return np.array(
            [physics.bind(a.joint).qpos.item() for a in self.all_actuators],
            np.float32,
        )

    @property
    def qvel(self) -> np.ndarray:
        """Get velocities of actuated joints, ordered like `all_actuators`."""
        physics = self._mojo.physics
        return np.array(
            [physics.bind(a.joint).qvel.item() for a in self.all_actuators],
            np.float32,
        )

    @property
    def get_accumulated_actions(self) -> np.ndarray:
        """Get accumulated actions since last reset."""
        return np.array(self._accumulated_actions, np.float32)

    @property
    def all_actuators(self) -> list[mjcf.Element]:
        """Get all actuators: positional (X, Y, Z) followed by rotational."""
        return [a for a in self._position_actuators if a] + [
            a for a in self._rotation_actuators if a
        ]

    @property
    def position_actuators(self) -> list[Optional[mjcf.Element]]:
        """Get all position actuators."""
        return self._position_actuators

    @property
    def rotation_actuators(self) -> list[Optional[mjcf.Element]]:
        """Get all rotation actuators."""
        return self._rotation_actuators

    @property
    def _pelvis_z(self) -> float:
        """Get height of the pelvis.

        Read from the Z joint when it is actuated, otherwise from the position
        of the pelvis body.
        """
        z_actuator = self._position_actuators[2]
        if z_actuator:
            joint = self._mojo.physics.bind(z_actuator.joint)
            # `.item()` instead of `float(array)`: converting an array with
            # ndim > 0 to a scalar is deprecated since NumPy 1.25
            return float(joint.qpos.item())
        pelvis = self._mojo.physics.bind(self._pelvis.mjcf)
        return float(pelvis.pos[2])

    def _set_position(self, position: np.ndarray):
        self._set_value(True, position)

    def _set_quaternion(self, quaternion: np.ndarray):
        # yaw_pitch_roll is flipped to get one value per X, Y, Z axis
        rotation = np.flip(np.array(Quaternion(quaternion).yaw_pitch_roll))
        self._set_value(False, rotation)

    def _set_value(self, position: bool, values: np.ndarray):
        """Instantly set per-axis values of the floating base.

        :param position: True to set positions, False to set rotations.
        :param values: One value per axis (X, Y, Z).
        """
        actuators = self._position_actuators if position else self._rotation_actuators
        assert len(values) == len(actuators)
        for i, (value, actuator) in enumerate(zip(values, actuators)):
            if actuator:
                bound_joint = self._mojo.physics.bind(actuator.joint)
                bound_joint.qpos = value
                bound_joint.qvel *= 0
                bound_joint.qacc *= 0
                # Keep the actuator target in sync with the new joint position
                self._mojo.physics.bind(actuator).ctrl = value
            elif position:
                # Not actuated axis: move the pelvis body directly
                pelvis = self._mojo.physics.bind(self._pelvis.mjcf)
                pelvis.pos[i] = value
            # Rotation around a not actuated axis is ignored

    def _add_actuator(self, positional: bool, axis: np.ndarray, actuator: mjcf.Element):
        """Add floating base actuator."""
        # Index of the largest axis component selects the X, Y or Z slot
        actuator_index = np.argmax(axis)
        actuators = self._position_actuators if positional else self._rotation_actuators
        actuators[actuator_index] = actuator


@dataclasses.dataclass
class GripperModel:
    """Robot gripper.

    Describes the gripper model which is attached to each hand of the robot.
    """

    # Path to the gripper model file
    asset: Path
    # Names of the pad geoms used to check if the gripper is holding an object
    pads: list[str]


class Robot:
    """H1 Robot.

    Loads the robot model, attaches a gripper to each hand and adapts the
    actuators of the model to the given action mode.
    """

    GRIPPER_RANGE = (0, 1)
    DELTA_LIMB_RANGE = (-0.1, 0.1)

    _PELVIS = "pelvis"
    _LIGHT = "light"
    _FREEJOINT = "freejoint"
    _PINCH = "pinch"
    _GRIPPER_OFFSET = np.array([np.pi / 2, np.pi / 2, 0])

    _WRIST_DOF = Dof(
        joint_type=JointType.HINGE,
        axis=(1, 0, 0),
        joint_range=(-1.5708, 1.5708),
        action_range=(-18, 18),
    )

    _DEFAULT_GRIPPER = GripperModel(
        ASSETS_PATH / "robotiq_2f85/2f85.xml",
        ["right_pad1", "right_pad2", "left_pad1", "left_pad2"],
    )
    _STABLE_GRIPPER = GripperModel(
        ASSETS_PATH / "robotiq_2f85/2f85_stable.xml",
        [
            "right_pad1",
            "right_pad2",
            "right_pad3",
            "right_pad4",
            "left_pad1",
            "left_pad2",
            "left_pad3",
            "left_pad4",
        ],
    )

    def __init__(
        self,
        action_mode: ActionMode,
        mojo: Optional[Mojo] = None,
        stable_gripper: bool = False,
    ):
        """Init.

        :param action_mode: Action mode the robot model is adapted to.
        :param mojo: Mojo instance to load the robot into; a new one based on
            `WORLD_MODEL` is created when omitted.
        :param stable_gripper: Use the "stable" gripper model instead of the
            default one.
        """
        self._action_mode = action_mode
        self._mojo = mojo or Mojo(WORLD_MODEL)
        self._gripper_model = (
            self._STABLE_GRIPPER if stable_gripper else self._DEFAULT_GRIPPER
        )
        self._body = self._mojo.load_model(ROBOT_MODEL, on_loaded=self._on_loaded)
        self._joints = [
            j for j in self._body.joints if j.mjcf.name != FLOATING_BASE_JOINT
        ]

        if not self._action_mode.floating_base:
            self._body.set_kinematic(True)

        # Bind robot to action mode
        self._action_mode.bind_robot(self, self._mojo)

    @property
    def action_mode(self) -> ActionMode:
        """Get action mode."""
        return self._action_mode

    @property
    def pelvis(self) -> Body:
        """Get pelvis."""
        return self._pelvis

    @property
    def limb_actuators(self) -> list[mjcf.Element]:
        """Get all limb actuators."""
        return self._limb_actuators

    @property
    def gripper_actuators(self) -> list[mjcf.Element]:
        """Get all gripper actuators."""
        return self._gripper_actuators

    @property
    def floating_base(self) -> Optional[RobotFloatingBase]:
        """Get floating base."""
        return self._floating_base

    @property
    def qpos(self) -> np.ndarray:
        """Get positions of all joints."""
        return np.array(
            [joint.get_joint_position() for joint in self._joints], np.float32
        )

    @property
    def qpos_grippers(self) -> np.ndarray:
        """Get current state of gripper actuators.

        The control of each gripper actuator is mapped from its control range
        back to `GRIPPER_RANGE` and rounded to the nearest integer.
        """
        qpos = []
        for actuator in self._gripper_actuators:
            bound_actuator = self._mojo.physics.bind(actuator)
            low = bound_actuator.ctrlrange[0]
            high = bound_actuator.ctrlrange[1]
            # Inverse of the interpolation done in `set_gripper_control`:
            # normalise by the width of the control range
            control = np.round((bound_actuator.ctrl.item() - low) / (high - low))
            control = np.clip(control, self.GRIPPER_RANGE[0], self.GRIPPER_RANGE[1])
            qpos.append(control)
        return np.array(qpos, np.float32)

    @property
    def qpos_actuated(self) -> np.ndarray:
        """Get positions of actuated joints.

        Ordered as: floating base (if any), limbs, grippers.
        """
        qpos = []
        if self.floating_base:
            qpos.extend(self._floating_base.qpos)
        for actuator in self._limb_actuators:
            qpos.append(self._mojo.physics.bind(actuator.joint).qpos.item())
        qpos.extend(self.qpos_grippers)
        return np.array(qpos, np.float32)

    @property
    def qvel(self) -> np.ndarray:
        """Get velocities of all joints."""
        return np.array(
            [joint.get_joint_velocity() for joint in self._joints], np.float32
        )

    @property
    def qvel_actuated(self) -> np.ndarray:
        """Get velocities of actuated joints.

        Ordered as: floating base (if any), limbs, grippers. Gripper velocity
        is the velocity of the tendon driven by the gripper actuator.
        """
        qvel = []
        if self.floating_base:
            qvel.extend(self._floating_base.qvel)
        for actuator in self._limb_actuators:
            qvel.append(self._mojo.physics.bind(actuator.joint).qvel.item())
        for actuator in self._gripper_actuators:
            qvel.append(self._mojo.physics.bind(actuator.tendon).velocity.item())
        return np.array(qvel, np.float32)

    def get_hand_pos(self, side: HandSide) -> np.ndarray:
        """Get position of robot hand site."""
        return self._mojo.physics.bind(self._hand_sites[side].mjcf).xpos.copy()

    def get_gripper_pos(self, side: HandSide) -> np.ndarray:
        """Get position of robot gripper end effector."""
        return self._mojo.physics.bind(self._gripper_sites[side].mjcf).xpos.copy()

    def is_gripper_holding_object(
        self, other: Union[Geom, Iterable[Geom], Prop], side: HandSide
    ) -> bool:
        """Check if the gripper is holding an object.

        The check is a collision test between the pads of the gripper on the
        given side and the colliders of ``other``.
        """
        other_colliders = get_colliders(other)
        return has_collided_collections(
            self._mojo.physics, self._gripper_pads[side], other_colliders
        )

    def set_pose(self, position: np.ndarray, orientation: np.ndarray):
        """Instantly set pose of the robot pelvis."""
        if self._action_mode.floating_base:
            self._floating_base.reset(position, orientation)
        else:
            self._pelvis.set_position(position)
            self._pelvis.set_quaternion(orientation)

    def set_gripper_control(
        self, actuator: mjcf.Element, ctrl: float, discrete_control: bool = True
    ):
        """Set control of gripper actuator.

        :param actuator: Gripper actuator to control.
        :param ctrl: Control value, clipped to `GRIPPER_RANGE`.
        :param discrete_control: Round the clipped value before applying it.
        """
        ctrl = np.clip(ctrl, *self.GRIPPER_RANGE)
        if discrete_control:
            ctrl = round(ctrl)
        # Map from `GRIPPER_RANGE` to the control range of the actuator
        ctrl = np.interp(ctrl, self.GRIPPER_RANGE, actuator.ctrlrange)
        self._mojo.physics.bind(actuator).ctrl = ctrl

    def get_gripper_control_range(self):
        """Get gripper control range."""
        return self.GRIPPER_RANGE

    def get_limb_control_range(self, actuator: mjcf.Element, absolute: bool):
        """Get control range of the limb actuator.

        :param actuator: Limb actuator.
        :param absolute: True to get the control range of the actuator itself,
            False to get the range of delta actions.
        """
        return actuator.ctrlrange if absolute else self.DELTA_LIMB_RANGE

    def _on_loaded(self, model: mjcf.RootElement):
        """Adapt the loaded robot model to the action mode."""
        # Remove all lights from the robot model
        lights = model.find_all(self._LIGHT)
        for light in lights:
            light.remove()

        self._gripper_sites: dict[HandSide, Site] = {}
        self._gripper_pads: dict[HandSide, list[Geom]] = {}
        self._hand_sites: dict[HandSide, Site] = {}
        for side in HandSide:
            if self._action_mode.wrist_dof:
                self._add_wrist(model, side)
            self._attach_gripper(model, side)

        self._pelvis: Body = Body.get(
            self._mojo, self._PELVIS, MujocoElement(self._mojo, model)
        )

        # Always remove freejoint
        if self._pelvis.is_kinematic():
            self._pelvis.set_kinematic(False)

        # Reset default position of pelvis
        self._pelvis.mjcf.pos = np.zeros(3)
        self._pelvis.mjcf.euler = np.zeros(3)

        # List of new positional actuators
        # It will be used to tune damping of controlled joints
        new_positional_actuators: list[mjcf.Element] = []

        # Setup floating base
        self._floating_base: Optional[RobotFloatingBase] = None
        if self._action_mode.floating_base:
            self._floating_base = RobotFloatingBase(
                self._pelvis, self._action_mode.floating_dofs, model, self._mojo
            )
            new_positional_actuators.extend(self._floating_base.all_actuators)

        # Assign limb actuators
        self._limb_actuators: list[mjcf.Element] = []
        # Iterate over a snapshot: actuators are removed inside the loop
        for actuator in list(model.actuator.motor):
            identifier = actuator.full_identifier
            if identifier not in HUMANOID_ACTUATORS:
                continue
            if self._action_mode.floating_base and not HUMANOID_ACTUATORS[identifier]:
                # Not used with the floating base: drop actuator and its joint
                if actuator.joint:
                    actuator.joint.remove()
                actuator.remove()
                continue
            if isinstance(self._action_mode, TorqueActionMode):
                self._limb_actuators.append(actuator)
            elif isinstance(self._action_mode, JointPositionActionMode):
                # Replace the motor with a position actuator of the same name
                actuator_name = actuator.name
                actuator_joint = actuator.joint
                actuator.remove()
                actuator = model.actuator.add(
                    "position",
                    name=actuator_name,
                    joint=actuator_joint,
                    kp=STIFFNESS_POSITION_ACTUATOR,
                    ctrlrange=actuator_joint.range,
                )
                self._limb_actuators.append(actuator)
                new_positional_actuators.append(actuator)

        joints = model.find_all("joint") or []
        # Sort limb actuators according to the joints tree
        self._limb_actuators.sort(key=lambda a: joints.index(a.joint) if a.joint else 0)

        # Assign gripper actuators
        self._gripper_actuators: list[mjcf.Element] = []
        for actuator in model.actuator.general:
            if actuator.name == GRIPPER_ACTUATOR:
                self._gripper_actuators.append(actuator)

        # Temporary instance of physics to simplify model editing
        physics_tmp = mjcf.Physics.from_mjcf_model(model)

        # Fix joint damping
        for actuator in new_positional_actuators:
            damping = get_critical_damping_from_stiffness(
                actuator.kp, actuator.joint.full_identifier, physics_tmp
            )
            actuator.joint.damping = damping

    def _attach_gripper(self, model: mjcf.RootElement, side: HandSide):
        """Load the gripper model and attach it to the hand site of ``side``."""

        def _on_gripper_loaded(gripper_model: mjcf.RootElement):
            # Make the model name unique per side
            gripper_model.model += f"_{side.value.lower()}"
            gripper_element = MujocoElement(self._mojo, gripper_model)
            self._gripper_sites[side] = Site.get(
                self._mojo, self._PINCH, gripper_element
            )
            self._gripper_pads[side] = [
                Geom.get(self._mojo, pad, gripper_element)
                for pad in self._gripper_model.pads
            ]

        hand_site: Site = Site.get(
            self._mojo, HAND_SITES[side], MujocoElement(self._mojo, model)
        )
        self._hand_sites[side] = hand_site

        gripper = self._mojo.load_model(
            str(self._gripper_model.asset),
            hand_site,
            on_loaded=_on_gripper_loaded,
        )
        gripper.mjcf.euler = self._GRIPPER_OFFSET
        self._mojo.mark_dirty()

        return gripper

    def _add_wrist(self, model: mjcf.RootElement, side: HandSide):
        """Add a wrist body with a motor driven joint for the hand of ``side``.

        The hand site is removed from its parent and re-created inside the new
        wrist body, which is placed at the position of the removed site.
        """
        joint_name = WRIST_JOINTS[side]
        site_name = HAND_SITES[side]

        site = mjcf_utils.safe_find(model, "site", site_name)
        site_pos = site.pos
        site_parent = site.parent
        site.remove()

        wrist = site_parent.add("body", name=f"{joint_name}_link", pos=site_pos)
        wrist.add(
            "inertial", pos="0 0 0", mass="1e-15", diaginertia="1e-15 1e-15 1e-15"
        )
        wrist.add("site", name=site_name)
        joint = wrist.add(
            "joint",
            type=self._WRIST_DOF.joint_type.value,
            name=joint_name,
            axis=self._WRIST_DOF.axis,
            range=self._WRIST_DOF.joint_range,
        )
        model.actuator.add(
            "motor",
            name=joint_name,
            joint=joint,
            ctrlrange=self._WRIST_DOF.action_range,
        )
